from datasets import Dataset
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader

# Toy linear-regression data: `num_samples` evenly spaced inputs on [-1, 1]
# (both endpoints included) with targets on the line y = 0.5 * x + 2.
num_samples = 1024

X = np.linspace(start=-1, stop=1, num=num_samples)
Y = X * 0.5 + 2

# One row per sample with columns "X" and "Y", then wrapped as a
# `datasets.Dataset` so it can be handed to a DataLoader.
df = pd.DataFrame(data={"X": X, "Y": Y})
toy_dataset = Dataset.from_pandas(df)


def collate_fn(batch):
    """Report which process is collating, then pass the batch through.

    Prints the current process id and its parent's process id on a single
    line, which makes it visible which process is running the collation.

    Args:
        batch: The samples handed over for collation.

    Returns:
        The very same ``batch`` object, unmodified.
    """
    import os

    pid_report = f"collate function pid : {os.getpid()}"
    ppid_report = f"collate function ppid : {os.getppid()}"
    print(pid_report, ppid_report)
    return batch


# Shuffled batches of 4 samples, loaded by 4 worker processes.
#
# `collate_fn` is passed explicitly: without it the DataLoader uses its
# default collation and the pid/ppid-reporting `collate_fn` defined in this
# file is never called. Because `collate_fn` returns its input unchanged,
# each batch is the plain list of samples rather than a default-collated one.
#
# NOTE: with num_workers > 0, iterating this loader on platforms that spawn
# worker processes (e.g. Windows, macOS) must happen under an
# `if __name__ == "__main__":` guard.
toy_dataloader = DataLoader(
    toy_dataset,
    batch_size=4,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    collate_fn=collate_fn,
)
